/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *notice, this list of conditions and the following disclaimer in the
 *documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Unit tests for threadblock level GEMV
*/

#include "../../common/cutlass_unit_test.h"

#include "cutlass/aligned_buffer.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/tensor_ref.h"

#include "cutlass/core_io.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"

#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/gemm.h"

#include "cutlass/gemm/threadblock/gemv.h"
#include "cutlass/gemm/threadblock/default_gemv_core.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace test {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Gemv, typename LongIndex, typename RefA, typename RefB,
          typename RefC>
__global__ void batched_gemv_threadblock_test_kernel(
        cutlass::gemm::GemmCoord problem_size, LongIndex stride_a,
        LongIndex stride_b, LongIndex stride_c, RefA ref_A, RefB ref_B,
        RefC ref_C) {
    typename Gemv::IteratorA::TensorCoord threadblock_offset_A(0, 0);
    typename Gemv::IteratorB::TensorCoord threadblock_offset_B(0, 0);
    typename Gemv::IteratorB::TensorCoord threadblock_offset_C(0, 0);

    // Move to the right batches for these threads
    ref_A.add_pointer_offset(threadIdx.y * stride_a);
    ref_B.add_pointer_offset(threadIdx.y * stride_b);
    ref_C.add_pointer_offset(threadIdx.y * stride_c);

    // Construct iterators to A and B operands
    typename Gemv::IteratorA::Params params_A(ref_A.layout());
    typename Gemv::IteratorA iterator_A(params_A, ref_A.data(),
                                        {problem_size.m(), problem_size.k()}, 0,
                                        threadblock_offset_A);
    typename Gemv::IteratorB::Params params_B(ref_B.layout());
    typename Gemv::IteratorB iterator_B(params_B, ref_B.data(),
                                        {problem_size.k(), problem_size.n()},
                                        threadIdx.x, threadblock_offset_B);

    Gemv gemv;

    typename Gemv::FragmentC accum;
    accum.clear();

    // Compute threadblock-scoped matrix multiply-add
    gemv(problem_size, accum, iterator_A, iterator_B, accum);

    // IteratorC is PitchLinear<> assumes n() contiguous
    typename Gemv::IteratorC::Params params_C(ref_C.layout());
    typename Gemv::IteratorC iterator_C(params_C, ref_C.data(),
                                        {problem_size.m(), problem_size.n()},
                                        threadIdx.x, threadblock_offset_C);
    iterator_C.store(accum);
}

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Shape_, typename ElementAB_, typename ElementC_,
          typename LayoutA_, typename LayoutB_, typename LayoutC_, int LDG_N,
          int LDG_K, int MAX_THREADS_PER_BLOCK = 512, bool DEBUG = false>
void batched_gemv_threadblock_test(cutlass::gemm::GemmCoord problem_size,
                                   int num_batch) {
    using Shape = Shape_;
    using ElementA = ElementAB_;
    using LayoutA = LayoutA_;
    using ElementB = ElementAB_;
    using LayoutB = LayoutB_;
    using ElementC = ElementC_;
    using LayoutC = LayoutC_;
    using ThreadShape = cutlass::gemm::GemmShape<1, LDG_N, LDG_K>;

    using Core = typename cutlass::gemm::threadblock::DefaultGemvCore<
            Shape, ThreadShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
            LayoutC>;

    if (DEBUG) {
        num_batch = 1;
    }

    using Mma = cutlass::gemm::threadblock::Gemv<Core>;

    // Create host tensors that will be the backing store for the batches
    // Note that no device memory is initially allocated
    cutlass::HostTensor<ElementA, LayoutA> matrix_A(
            {problem_size.m(), problem_size.k()}, false);
    cutlass::HostTensor<ElementB, LayoutB> matrix_B(
            {problem_size.k(), problem_size.n()}, false);
    cutlass::HostTensor<ElementC, LayoutC> matrix_C_computed(
            {problem_size.m(), problem_size.n()}, false);
    cutlass::HostTensor<ElementC, LayoutC> matrix_C_reference(
            {problem_size.m(), problem_size.n()}, false);

    // Reserve memory for the batch of tensors
    matrix_A.reserve(problem_size.m() * problem_size.k() * num_batch);
    matrix_B.reserve(problem_size.n() * problem_size.k() * num_batch);
    matrix_C_computed.reserve(problem_size.m() * problem_size.n() * num_batch);
    matrix_C_reference.reserve(problem_size.m() * problem_size.n() * num_batch,
                               false);

    // Fill eatch tensor batch
    const int seed = 6834;
    for (int b = 0; b < num_batch; b++) {
        if (DEBUG) {
            cutlass::reference::host::BlockFillSequential(
                    matrix_A.host_data_ptr_offset(b * matrix_A.capacity()),
                    matrix_A.capacity());
            cutlass::reference::host::BlockFillSequential(
                    matrix_B.host_data_ptr_offset(b * matrix_B.capacity()),
                    matrix_B.capacity());
        } else {
            cutlass::reference::host::TensorFillRandomUniform(
                    matrix_A.host_view(b * matrix_A.capacity()), seed + 1660, 8,
                    -8, 0);

            cutlass::reference::host::TensorFillRandomUniform(
                    matrix_B.host_view(b * matrix_B.capacity()), seed + 1880, 8,
                    -8, 0);
        }

        cutlass::reference::host::TensorFill(
                matrix_C_computed.host_view(b * matrix_C_computed.capacity()));
        cutlass::reference::host::TensorFill(matrix_C_reference.host_view(
                b * matrix_C_reference.capacity()));
    }

    matrix_A.sync_device();
    matrix_B.sync_device();
    matrix_C_computed.sync_device();

    dim3 grid(1, 1);  // only 1 CTA is used
    dim3 block(Shape::kN / LDG_N, num_batch, 1);

#if 0
  printf("block dim = %d x %d\n", block.x, block.y);
#endif

    // Some sanity checks
    EXPECT_TRUE(problem_size.n() % LDG_N == 0);
    EXPECT_TRUE(block.x * block.y <= MAX_THREADS_PER_BLOCK);

    test::gemm::threadblock::batched_gemv_threadblock_test_kernel<Mma>
            <<<grid, block>>>(problem_size, matrix_A.capacity(),
                              matrix_B.capacity(), matrix_C_computed.capacity(),
                              matrix_A.device_ref(), matrix_B.device_ref(),
                              matrix_C_computed.device_ref());

    cudaError_t result = cudaDeviceSynchronize();
    EXPECT_EQ(result, cudaSuccess)
            << " kernel error: " << cudaGetErrorString(result);

    matrix_C_computed.sync_host();

    // Compute the batched gemms
    for (int b = 0; b < num_batch; b++) {
        cutlass::reference::host::Gemm<ElementA, LayoutA, ElementB, LayoutB,
                                       ElementC, LayoutC, ElementC, ElementC>
                reference_gemm;

        reference_gemm(
                problem_size.mnk(), ElementC(1),
                matrix_A.host_ref(b * matrix_A.capacity()),
                matrix_B.host_ref(b * matrix_B.capacity()), ElementC(0),
                matrix_C_reference.host_ref(b * matrix_C_computed.capacity()));

        bool passed = cutlass::reference::host::TensorEquals(
                matrix_C_computed.host_view(b * matrix_C_computed.capacity()),
                matrix_C_reference.host_view(b *
                                             matrix_C_reference.capacity()));

        EXPECT_TRUE(passed)
                //<< "A:\n" << matrix_A.host_view() << "\n"
                //<< "B:\n" << matrix_B.host_view() << "\n"
                << "Batch: " << b << "\n"
                << "Reference:\n"
                << matrix_C_reference.host_view(b *
                                                matrix_C_reference.capacity())
                << "\n"
                << "Computed:\n"
                << matrix_C_computed.host_view(b * matrix_C_computed.capacity())
                << "\n";
    }
}

}  // namespace threadblock
}  // namespace gemm
}  // namespace test

/////////////////////////////////////////////////////////////////////////////////////////////////

// A: ColumnMajor
// B: RowMajor
// C: ColumnMajor

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_crc_fp32_fp32_2N_2K) {
    using namespace test::gemm::threadblock;
    cutlass::gemm::GemmCoord problem_size(1, 64, 64);
    const int num_batch = 4;
    const int LDG_N = 2;
    const int LDG_K = 2;

    using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
    batched_gemv_threadblock_test<Shape, float, float,
                                  cutlass::layout::ColumnMajor,
                                  cutlass::layout::RowMajor,
                                  cutlass::layout::ColumnMajor, LDG_N, LDG_K>(
            problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 5x1x128x128_crc_fp32_fp32_4N_4K) {
    using namespace test::gemm::threadblock;
    cutlass::gemm::GemmCoord problem_size(1, 128, 128);
    const int num_batch = 5;
    const int LDG_N = 4;
    const int LDG_K = 4;

    using Shape = cutlass::gemm::GemmShape<1, 128, LDG_K>;
    batched_gemv_threadblock_test<Shape, float, float,
                                  cutlass::layout::ColumnMajor,
                                  cutlass::layout::RowMajor,
                                  cutlass::layout::ColumnMajor, LDG_N, LDG_K>(
            problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_crc_fp32_fp32_1N_4K) {
    using namespace test::gemm::threadblock;
    cutlass::gemm::GemmCoord problem_size(1, 17, 64);
    const int num_batch = 16;
    const int LDG_N = 1;
    const int LDG_K = 4;

    using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
    batched_gemv_threadblock_test<Shape, float, float,
                                  cutlass::layout::ColumnMajor,
                                  cutlass::layout::RowMajor,
                                  cutlass::layout::ColumnMajor, LDG_N, LDG_K>(
            problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_crc_fp16_fp32_2N_2K) {
    using namespace test::gemm::threadblock;
    cutlass::gemm::GemmCoord problem_size(1, 64, 64);
    const int num_batch = 4;
    const int LDG_N = 2;
    const int LDG_K = 2;

    using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
    batched_gemv_threadblock_test<Shape, cutlass::half_t, float,
                                  cutlass::layout::ColumnMajor,
                                  cutlass::layout::RowMajor,
                                  cutlass::layout::ColumnMajor, LDG_N, LDG_K>(
            problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_crc_fp16_fp32_2N_8K) {
    using namespace test::gemm::threadblock;
    cutlass::gemm::GemmCoord problem_size(1, 64, 64);
    const int num_batch = 4;
    const int LDG_N = 2;
    const int LDG_K = 8;

    using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
    batched_gemv_threadblock_test<Shape, cutlass::half_t, float,
                                  cutlass::layout::ColumnMajor,
                                  cutlass::layout::RowMajor,
                                  cutlass::layout::ColumnMajor, LDG_N, LDG_K>(
            problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_crc_fp16_fp32_1N_4K) {
    using namespace test::gemm::threadblock;
    cutlass::gemm::GemmCoord problem_size(1, 17, 64);
    const int num_batch = 16;
    const int LDG_N = 1;
    const int LDG_K = 4;

    using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
    batched_gemv_threadblock_test<Shape, cutlass::half_t, float,
                                  cutlass::layout::ColumnMajor,
                                  cutlass::layout::RowMajor,
                                  cutlass::layout::ColumnMajor, LDG_N, LDG_K>(
            problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_crc_i8_i32_2N_4K) {
    using namespace test::gemm::threadblock;
    cutlass::gemm::GemmCoord problem_size(1, 64, 64);
    const int num_batch = 4;
    const int LDG_N = 2;
    const int LDG_K = 4;

    using Shape = cutlass::gemm::GemmShape<1, 128, LDG_K>;
    batched_gemv_threadblock_test<Shape, int8_t, int32_t,
                                  cutlass::layout::ColumnMajor,
                                  cutlass::layout::RowMajor,
                                  cutlass::layout::ColumnMajor, LDG_N, LDG_K>(
            problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_crc_i8_i32_1N_4K) {
    using namespace test::gemm::threadblock;
    cutlass::gemm::GemmCoord problem_size(1, 17, 64);
    const int num_batch = 16;
    const int LDG_N = 1;
    const int LDG_K = 4;

    using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
    batched_gemv_threadblock_test<Shape, int8_t, int32_t,
                                  cutlass::layout::ColumnMajor,
                                  cutlass::layout::RowMajor,
                                  cutlass::layout::ColumnMajor, LDG_N, LDG_K>(
            problem_size, num_batch);
}

// A: RowMajor
// B: ColumnMajor
// C: RowMajor

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcr_fp32_fp32_2N_2K) {
    using namespace test::gemm::threadblock;
    cutlass::gemm::GemmCoord problem_size(1, 64, 64);
    const int num_batch = 4;
    const int LDG_N = 2;
    const int LDG_K = 2;

    using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
    batched_gemv_threadblock_test<Shape, float, float,
                                  cutlass::layout::RowMajor,
                                  cutlass::layout::ColumnMajor,
                                  cutlass::layout::RowMajor, LDG_N, LDG_K>(
            problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 5x1x128x128_rcr_fp32_fp32_4N_4K) {
    using namespace test::gemm::threadblock;
    cutlass::gemm::GemmCoord problem_size(1, 128, 128);
    const int num_batch = 5;
    const int LDG_N = 4;
    const int LDG_K = 4;

    using Shape = cutlass::gemm::GemmShape<1, 128, LDG_K>;
    batched_gemv_threadblock_test<Shape, float, float,
                                  cutlass::layout::RowMajor,
                                  cutlass::layout::ColumnMajor,
                                  cutlass::layout::RowMajor, LDG_N, LDG_K>(
            problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcr_fp32_fp32_1N_4K) {
    using namespace test::gemm::threadblock;
    cutlass::gemm::GemmCoord problem_size(1, 17, 64);
    const int num_batch = 16;
    const int LDG_N = 1;
    const int LDG_K = 4;

    using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
    batched_gemv_threadblock_test<Shape, float, float,
                                  cutlass::layout::RowMajor,
                                  cutlass::layout::ColumnMajor,
                                  cutlass::layout::RowMajor, LDG_N, LDG_K>(
            problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcr_fp16_fp32_2N_2K) {
    using namespace test::gemm::threadblock;
    cutlass::gemm::GemmCoord problem_size(1, 64, 64);
    const int num_batch = 4;
    const int LDG_N = 2;
    const int LDG_K = 2;

    using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
    batched_gemv_threadblock_test<Shape, cutlass::half_t, float,
                                  cutlass::layout::RowMajor,
                                  cutlass::layout::ColumnMajor,
                                  cutlass::layout::RowMajor, LDG_N, LDG_K>(
            problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcr_fp16_fp32_2N_8K) {
    using namespace test::gemm::threadblock;
    cutlass::gemm::GemmCoord problem_size(1, 64, 64);
    const int num_batch = 4;
    const int LDG_N = 2;
    const int LDG_K = 8;

    using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
    batched_gemv_threadblock_test<Shape, cutlass::half_t, float,
                                  cutlass::layout::RowMajor,
                                  cutlass::layout::ColumnMajor,
                                  cutlass::layout::RowMajor, LDG_N, LDG_K>(
            problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcr_fp16_fp32_1N_4K) {
    using namespace test::gemm::threadblock;
    cutlass::gemm::GemmCoord problem_size(1, 17, 64);
    const int num_batch = 16;
    const int LDG_N = 1;
    const int LDG_K = 4;

    using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
    batched_gemv_threadblock_test<Shape, cutlass::half_t, float,
                                  cutlass::layout::RowMajor,
                                  cutlass::layout::ColumnMajor,
                                  cutlass::layout::RowMajor, LDG_N, LDG_K>(
            problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcr_i8_i32_2N_4K) {
    using namespace test::gemm::threadblock;
    cutlass::gemm::GemmCoord problem_size(1, 64, 64);
    const int num_batch = 4;
    const int LDG_N = 2;
    const int LDG_K = 4;

    using Shape = cutlass::gemm::GemmShape<1, 128, LDG_K>;
    batched_gemv_threadblock_test<Shape, int8_t, int32_t,
                                  cutlass::layout::RowMajor,
                                  cutlass::layout::ColumnMajor,
                                  cutlass::layout::RowMajor, LDG_N, LDG_K>(
            problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcr_i8_i32_1N_4K) {
    using namespace test::gemm::threadblock;
    cutlass::gemm::GemmCoord problem_size(1, 17, 64);
    const int num_batch = 16;
    const int LDG_N = 1;
    const int LDG_K = 4;

    using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
    batched_gemv_threadblock_test<Shape, int8_t, int32_t,
                                  cutlass::layout::RowMajor,
                                  cutlass::layout::ColumnMajor,
                                  cutlass::layout::RowMajor, LDG_N, LDG_K>(
            problem_size, num_batch);
}

// A: RowMajor
// B: ColumnMajor
// C: ColumnMajor

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcc_fp32_fp32_2N_2K) {
    using namespace test::gemm::threadblock;
    cutlass::gemm::GemmCoord problem_size(1, 64, 64);
    const int num_batch = 4;
    const int LDG_N = 2;
    const int LDG_K = 2;

    using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
    batched_gemv_threadblock_test<Shape, float, float,
                                  cutlass::layout::RowMajor,
                                  cutlass::layout::ColumnMajor,
                                  cutlass::layout::ColumnMajor, LDG_N, LDG_K>(
            problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 5x1x128x128_rcc_fp32_fp32_4N_4K) {
    using namespace test::gemm::threadblock;
    cutlass::gemm::GemmCoord problem_size(1, 128, 128);
    const int num_batch = 5;
    const int LDG_N = 4;
    const int LDG_K = 4;

    using Shape = cutlass::gemm::GemmShape<1, 128, LDG_K>;
    batched_gemv_threadblock_test<Shape, float, float,
                                  cutlass::layout::RowMajor,
                                  cutlass::layout::ColumnMajor,
                                  cutlass::layout::ColumnMajor, LDG_N, LDG_K>(
            problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcc_fp32_fp32_1N_4K) {
    using namespace test::gemm::threadblock;
    cutlass::gemm::GemmCoord problem_size(1, 17, 64);
    const int num_batch = 16;
    const int LDG_N = 1;
    const int LDG_K = 4;

    using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
    batched_gemv_threadblock_test<Shape, float, float,
                                  cutlass::layout::RowMajor,
                                  cutlass::layout::ColumnMajor,
                                  cutlass::layout::ColumnMajor, LDG_N, LDG_K>(
            problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcc_fp16_fp32_2N_2K) {
    using namespace test::gemm::threadblock;
    cutlass::gemm::GemmCoord problem_size(1, 64, 64);
    const int num_batch = 4;
    const int LDG_N = 2;
    const int LDG_K = 2;

    using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
    batched_gemv_threadblock_test<Shape, cutlass::half_t, float,
                                  cutlass::layout::RowMajor,
                                  cutlass::layout::ColumnMajor,
                                  cutlass::layout::ColumnMajor, LDG_N, LDG_K>(
            problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcc_fp16_fp32_2N_8K) {
    using namespace test::gemm::threadblock;
    cutlass::gemm::GemmCoord problem_size(1, 64, 64);
    const int num_batch = 4;
    const int LDG_N = 2;
    const int LDG_K = 8;

    using Shape = cutlass::gemm::GemmShape<1, 64, LDG_K>;
    batched_gemv_threadblock_test<Shape, cutlass::half_t, float,
                                  cutlass::layout::RowMajor,
                                  cutlass::layout::ColumnMajor,
                                  cutlass::layout::ColumnMajor, LDG_N, LDG_K>(
            problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcc_fp16_fp32_1N_4K) {
    using namespace test::gemm::threadblock;
    cutlass::gemm::GemmCoord problem_size(1, 17, 64);
    const int num_batch = 16;
    const int LDG_N = 1;
    const int LDG_K = 4;

    using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
    batched_gemv_threadblock_test<Shape, cutlass::half_t, float,
                                  cutlass::layout::RowMajor,
                                  cutlass::layout::ColumnMajor,
                                  cutlass::layout::ColumnMajor, LDG_N, LDG_K>(
            problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 4x1x64x64_rcc_i8_i32_2N_4K) {
    using namespace test::gemm::threadblock;
    cutlass::gemm::GemmCoord problem_size(1, 64, 64);
    const int num_batch = 4;
    const int LDG_N = 2;
    const int LDG_K = 4;

    using Shape = cutlass::gemm::GemmShape<1, 128, LDG_K>;
    batched_gemv_threadblock_test<Shape, int8_t, int32_t,
                                  cutlass::layout::RowMajor,
                                  cutlass::layout::ColumnMajor,
                                  cutlass::layout::ColumnMajor, LDG_N, LDG_K>(
            problem_size, num_batch);
}

TEST(SM50_batched_gemv_threadblock, 16x1x17x64_rcc_i8_i32_1N_4K) {
    using namespace test::gemm::threadblock;
    cutlass::gemm::GemmCoord problem_size(1, 17, 64);
    const int num_batch = 16;
    const int LDG_N = 1;
    const int LDG_K = 4;

    using Shape = cutlass::gemm::GemmShape<1, 32, LDG_K>;
    batched_gemv_threadblock_test<Shape, int8_t, int32_t,
                                  cutlass::layout::RowMajor,
                                  cutlass::layout::ColumnMajor,
                                  cutlass::layout::ColumnMajor, LDG_N, LDG_K>(
            problem_size, num_batch);
}
